{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "name": "Convert MelGAN generator from pytorch to tensorflow",
      "provenance": [],
      "collapsed_sections": [],
      "toc_visible": true,
      "include_colab_link": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.6.4"
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "NtibXctgCmhV"
      },
      "source": [
        "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/kan-bayashi/ParallelWaveGAN/blob/master/notebooks/convert_melgan_from_pytorch_to_tensorflow.ipynb)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "YPtDpfkQ9R8G"
      },
      "source": [
        "# Convert MelGAN generator from pytorch to tensorflow\n",
        "\n",
        "This notebook proivdies the procedure of conversion of MelGAN generator from pytorch to tensorflow.  \n",
        "Tensorflow version can accelerate the inference speed on both CPU and GPU."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "dB-N-PJU9Txg",
        "outputId": "e6721858-9f61-4006-9b2f-40b45f8d9b3e",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 343
        }
      },
      "source": [
        "# install libraries for google colab\n",
        "!git clone https://github.com/kan-bayashi/ParallelWaveGAN.git\n",
        "!cd ParallelWaveGAN; pip install -qq .\n",
        "!pip install -qq tensorflow-gpu==2.1"
      ],
      "execution_count": 0,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Cloning into 'ParallelWaveGAN'...\n",
            "remote: Enumerating objects: 4, done.\u001b[K\n",
            "remote: Counting objects: 100% (4/4), done.\u001b[K\n",
            "remote: Compressing objects: 100% (4/4), done.\u001b[K\n",
            "remote: Total 3565 (delta 0), reused 0 (delta 0), pack-reused 3561\u001b[K\n",
            "Receiving objects: 100% (3565/3565), 23.89 MiB | 40.70 MiB/s, done.\n",
            "Resolving deltas: 100% (1961/1961), done.\n",
            "\u001b[K     |████████████████████████████████| 1.6MB 4.7MB/s \n",
            "\u001b[K     |████████████████████████████████| 204kB 49.8MB/s \n",
            "\u001b[?25h  Building wheel for parallel-wavegan (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
            "  Building wheel for librosa (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
            "  Building wheel for kaldiio (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
            "\u001b[K     |████████████████████████████████| 421.8MB 40kB/s \n",
            "\u001b[K     |████████████████████████████████| 450kB 51.4MB/s \n",
            "\u001b[?25h  Building wheel for gast (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
            "\u001b[31mERROR: tensorflow 2.2.0rc1 has requirement gast==0.3.3, but you'll have gast 0.2.2 which is incompatible.\u001b[0m\n",
            "\u001b[31mERROR: tensorflow 2.2.0rc1 has requirement tensorflow-estimator<2.3.0,>=2.2.0rc0, but you'll have tensorflow-estimator 2.1.0 which is incompatible.\u001b[0m\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "EwYMtOUs9R8I",
        "colab": {}
      },
      "source": [
        "import os\n",
        "import numpy as np\n",
        "import torch\n",
        "import tensorflow as tf\n",
        "import yaml\n",
        "from parallel_wavegan.models import MelGANGenerator\n",
        "from parallel_wavegan.models.tf_models import TFMelGANGenerator"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "7DWtwgKZ9R8K"
      },
      "source": [
        "## Define Tensorflow and Pytorch models"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "wevXSH0n9R8L",
        "colab": {}
      },
      "source": [
        "# load vocoder config \n",
        "vocoder_conf = 'ParallelWaveGAN/egs/ljspeech/voc1/conf/melgan.v1.long.yaml'\n",
        "with open(vocoder_conf) as f:\n",
        "    config = yaml.load(f, Loader=yaml.Loader)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "XOK6AuWW9R8N",
        "outputId": "a60edb70-d9d0-4fe3-bf0f-7b0b2bda4fe9",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 307
        }
      },
      "source": [
        "# define Tensorflow MelGAN generator\n",
        "tf.compat.v1.disable_eager_execution()\n",
        "inputs = tf.keras.Input(batch_shape=[None, None, 80], dtype=tf.float32)\n",
        "audio = TFMelGANGenerator(**config[\"generator_params\"])(inputs)\n",
        "tf_melgan = tf.keras.models.Model(inputs, audio)\n",
        "tf_melgan.summary()"
      ],
      "execution_count": 0,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow_core/python/ops/resource_variable_ops.py:1786: calling BaseResourceVariable.__init__ (from tensorflow.python.ops.resource_variable_ops) with constraint is deprecated and will be removed in a future version.\n",
            "Instructions for updating:\n",
            "If using Keras pass *_constraint arguments to layers.\n",
            "Model: \"model\"\n",
            "_________________________________________________________________\n",
            "Layer (type)                 Output Shape              Param #   \n",
            "=================================================================\n",
            "input_1 (InputLayer)         [(None, None, 80)]        0         \n",
            "_________________________________________________________________\n",
            "tf_mel_gan_generator (TFMelG (None, None, 1)           4260257   \n",
            "=================================================================\n",
            "Total params: 4,260,257\n",
            "Trainable params: 4,260,257\n",
            "Non-trainable params: 0\n",
            "_________________________________________________________________\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "wESkPTHZ9R8Q",
        "colab": {}
      },
      "source": [
        "# define pytorch model\n",
        "pytorch_melgan = MelGANGenerator(**config[\"generator_params\"])\n",
        "pytorch_melgan.remove_weight_norm()  # needed since TFMelGANGenerator does not support weight norm\n",
        "pytorch_melgan = pytorch_melgan.to(\"cpu\")"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "q2PNLx9m9R8S",
        "outputId": "f5ecd5a8-fffe-44e2-c425-d2596ae4a85f",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 53
        }
      },
      "source": [
        "# check the number of variables are the same\n",
        "state_dict = pytorch_melgan.state_dict()\n",
        "tf_vars = tf.compat.v1.global_variables()\n",
        "print(\"Number Tensorflow variables: \", len(tf_vars))\n",
        "print(\"Number Pytorch variables: \", len(state_dict.keys()))"
      ],
      "execution_count": 0,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Number Tensorflow variables:  84\n",
            "Number Pytorch variables:  84\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "I4Ksgn1o9R8U"
      },
      "source": [
        "## Convert parameters from pytorch to tensorflow"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "JzhiUmpv9R8V",
        "colab": {}
      },
      "source": [
        "def reorder_tf_vars(tf_vars):\n",
        "    \"\"\"\n",
        "    Reorder tensorflow variables to match with pytorch state dict order. \n",
        "    Since each tensorflow layer's order is bias -> weight while pytorch's \n",
        "    one is weight -> bias, we change the order of variables.\n",
        "    \"\"\"\n",
        "    tf_new_var = []\n",
        "    for i in range(0, len(tf_vars), 2):\n",
        "        tf_new_var.append(tf_vars[i + 1])\n",
        "        tf_new_var.append(tf_vars[i])\n",
        "    return tf_new_var"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "cono0O8n9R8X",
        "colab": {}
      },
      "source": [
        "# change the order of variables to be the same as pytorch\n",
        "tf_vars = reorder_tf_vars(tf_vars)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "qNC78Bgy9R8Z",
        "colab": {}
      },
      "source": [
        "def convert_weights_pytorch_to_tensorflow(weights_pytorch):\n",
        "    \"\"\"\n",
        "    Convert pytorch Conv1d weight variable to tensorflow Conv2D weights.\n",
        "    Pytorch (f_output, f_input, kernel_size) -> TF (kernel_size, f_input, 1, f_output)\n",
        "    \"\"\"\n",
        "    weights_tensorflow = np.transpose(weights_pytorch, (0,2,1))  # [f_output, kernel_size, f_input]\n",
        "    weights_tensorflow = np.transpose(weights_tensorflow, (1,0,2))  # [kernel-size, f_output, f_input]\n",
        "    weights_tensorflow = np.transpose(weights_tensorflow, (0,2,1))  # [kernel-size, f_input, f_output]\n",
        "    weights_tensorflow = np.expand_dims(weights_tensorflow, 1)  # [kernel-size, f_input, 1, f_output]\n",
        "    return weights_tensorflow"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "Hkciz2H49R8b",
        "colab": {}
      },
      "source": [
        "# convert pytorch's variables to tensorflow's one\n",
        "for i, var_name in enumerate(state_dict):\n",
        "    try:\n",
        "        tf_name = tf_vars[i]\n",
        "        torch_tensor = state_dict[var_name].numpy()\n",
        "        if torch_tensor.ndim >= 2:\n",
        "            tensorflow_tensor = convert_weights_pytorch_to_tensorflow(torch_tensor)\n",
        "        else:\n",
        "            tensorflow_tensor = torch_tensor\n",
        "        tf.keras.backend.set_value(tf_name, tensorflow_tensor)\n",
        "    except:\n",
        "        print(tf_name)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "VL0YFaji9R8d"
      },
      "source": [
        "## Check both outputs are almost the equal"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "e8FAomYs9R8e",
        "colab": {}
      },
      "source": [
        "fake_mels = np.random.sample((1, 80, 250)).astype(np.float32)\n",
        "with torch.no_grad():\n",
        "    y_pytorch = pytorch_melgan(torch.Tensor(fake_mels))\n",
        "y_tensorflow = tf_melgan.predict(np.transpose(fake_mels, (0, 2, 1)))\n",
        "np.testing.assert_almost_equal(\n",
        "    y_pytorch[0, 0, :].numpy(),\n",
        "    y_tensorflow[0, :, 0],\n",
        ")"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "jy0c0qv39R8g"
      },
      "source": [
        "## Save Tensorflow and Pytorch models for benchmark"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "ItFZ65Jn9R8g",
        "outputId": "dc573931-0485-44e4-b230-dbd37c0b6d6e",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 53
        }
      },
      "source": [
        "os.makedirs(\"./checkpoint/tensorflow_generator/\", exist_ok=True)\n",
        "os.makedirs(\"./checkpoint/pytorch_generator/\", exist_ok=True)\n",
        "tf.saved_model.save(tf_melgan, \"./checkpoint/tensorflow_generator/\")\n",
        "torch.save(pytorch_melgan.state_dict(), \"./checkpoint/pytorch_generator/checkpoint.pkl\")"
      ],
      "execution_count": 0,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "INFO:tensorflow:Assets written to: ./checkpoint/tensorflow_generator/assets\n"
          ],
          "name": "stdout"
        },
        {
          "output_type": "stream",
          "text": [
            "INFO:tensorflow:Assets written to: ./checkpoint/tensorflow_generator/assets\n"
          ],
          "name": "stderr"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "0BsqIdyu9R8i"
      },
      "source": [
        "## Inference speed benchmark on GPU\n",
        "\n",
        "From here, we will compare the inference speed using pytorch model and converted tensorflow model."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "FU_NIGsu9R8j",
        "colab": {}
      },
      "source": [
        "# To enable eager mode, we need to restart the runtime\n",
        "import os\n",
        "os._exit(00)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "_PbcqANV9R8l",
        "colab": {}
      },
      "source": [
        "import numpy as np\n",
        "import torch\n",
        "import yaml\n",
        "import tensorflow as tf\n",
        "from tensorflow.python.framework import convert_to_constants\n",
        "from tensorflow.python.saved_model import signature_constants\n",
        "from tensorflow.python.saved_model import tag_constants\n",
        "from parallel_wavegan.models import MelGANGenerator"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "gg-E2hePEwTB",
        "colab": {}
      },
      "source": [
        "# setup pytorch model\n",
        "vocoder_conf = 'ParallelWaveGAN/egs/ljspeech/voc1/conf/melgan.v1.long.yaml'\n",
        "with open(vocoder_conf) as f:\n",
        "    config = yaml.load(f, Loader=yaml.Loader)\n",
        "pytorch_melgan = MelGANGenerator(**config[\"generator_params\"])\n",
        "pytorch_melgan.remove_weight_norm()\n",
        "pytorch_melgan.load_state_dict(torch.load(\n",
        "    \"./checkpoint/pytorch_generator/checkpoint.pkl\", map_location=\"cpu\"))\n",
        "pytorch_melgan = pytorch_melgan.to(\"cuda\").eval()"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "OVUMAHuT9R8n",
        "colab": {}
      },
      "source": [
        "# setup tensorflow model\n",
        "class TFMelGAN(object):\n",
        "    def __init__(self, saved_path):\n",
        "        self.saved_path = saved_path\n",
        "        self.graph = self._load_model()\n",
        "        self.mels = None\n",
        "        self.audios = None\n",
        "    \n",
        "    def _load_model(self):\n",
        "        saved_model_loaded = tf.saved_model.load(\n",
        "            self.saved_path, tags=[tag_constants.SERVING])\n",
        "        graph_func = saved_model_loaded.signatures[\n",
        "            signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY]\n",
        "        graph_func = convert_to_constants.convert_variables_to_constants_v2(graph_func)\n",
        "        return graph_func\n",
        "\n",
        "    def set_mels(self, values):\n",
        "        self.mels = tf.identity(tf.constant(values))\n",
        "\n",
        "    def get_mels(self):\n",
        "        return self.mels\n",
        "\n",
        "    def get_audio(self):\n",
        "        return self.audios\n",
        "\n",
        "    def run_inference(self):\n",
        "        self.audios = self.graph(self.mels)[0]\n",
        "        return self.audios   \n",
        "    \n",
        "tf_melgan = TFMelGAN(saved_path='./checkpoint/tensorflow_generator/')"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "XC7Y0H7r9R8p",
        "colab": {}
      },
      "source": [
        "# warmup\n",
        "fake_mels = np.random.sample((4, 1500, 80)).astype(np.float32)\n",
        "tf_melgan.set_mels(fake_mels)\n",
        "fake_mels = torch.Tensor(fake_mels).transpose(2, 1).to(\"cuda\")\n",
        "with torch.no_grad():\n",
        "    y = pytorch_melgan(fake_mels)\n",
        "y = tf_melgan.run_inference()"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "CrgQbMYWXDQy",
        "colab_type": "code",
        "outputId": "3c7c0a23-39f0-458a-cf34-4a47a705802a",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 305
        }
      },
      "source": [
        "!nvidia-smi"
      ],
      "execution_count": 0,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Sun Mar 29 12:57:18 2020       \n",
            "+-----------------------------------------------------------------------------+\n",
            "| NVIDIA-SMI 440.64.00    Driver Version: 418.67       CUDA Version: 10.1     |\n",
            "|-------------------------------+----------------------+----------------------+\n",
            "| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |\n",
            "| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |\n",
            "|===============================+======================+======================|\n",
            "|   0  Tesla P100-PCIE...  Off  | 00000000:00:04.0 Off |                    0 |\n",
            "| N/A   35C    P0    37W / 250W |   5903MiB / 16280MiB |      0%      Default |\n",
            "+-------------------------------+----------------------+----------------------+\n",
            "                                                                               \n",
            "+-----------------------------------------------------------------------------+\n",
            "| Processes:                                                       GPU Memory |\n",
            "|  GPU       PID   Type   Process name                             Usage      |\n",
            "|=============================================================================|\n",
            "+-----------------------------------------------------------------------------+\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "GhKgA5pk9R8r",
        "outputId": "945ccc65-0bfd-4a7e-9db0-1f6df7c4ab5b",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 53
        }
      },
      "source": [
        "%%time\n",
        "# check pytorch inference speed\n",
        "with torch.no_grad():\n",
        "    y = pytorch_melgan(fake_mels)"
      ],
      "execution_count": 0,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "CPU times: user 8.71 ms, sys: 275 µs, total: 8.99 ms\n",
            "Wall time: 11.5 ms\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "vYZR0ihS9R8s",
        "outputId": "16f22663-acb0-44cf-f84a-38dfd0df406b",
        "scrolled": true,
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 53
        }
      },
      "source": [
        "%%time\n",
        "# check tensorflow inference speed\n",
        "y = tf_melgan.run_inference()"
      ],
      "execution_count": 0,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "CPU times: user 6.62 ms, sys: 918 µs, total: 7.54 ms\n",
            "Wall time: 10.9 ms\n"
          ],
          "name": "stdout"
        }
      ]
    }
  ]
}
